"""
This script uses radiomics features to identify scar in non-contrast (no gadolinium, non-Gd) cine images.
All radiomics features were extracted off-line beforehand and stored in CSV files.
"""
import os
import numpy as np
import random
import pandas as pd
import time
# Wall-clock timestamp (seconds since the epoch) recorded at script start-up.
start = time.time()

import matplotlib.pyplot as plt
init_rs = 2021  # global random seed for every stochastic step
# Seed both generators: scikit-learn estimators built with random_state=None
# draw from NumPy's global generator, which random.seed() does not touch.
random.seed(init_rs)
np.random.seed(init_rs)

from sklearn.feature_selection import SelectFromModel
from sklearn.linear_model import LogisticRegression, LassoCV
from sklearn.ensemble import GradientBoostingClassifier, StackingClassifier
from sklearn.svm import SVC
from sklearn import metrics
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Expose only the selected GPU id(s) to CUDA-aware libraries.
num_gpus = 1
gpus = '0'
os.environ.update({'CUDA_VISIBLE_DEVICES': gpus})
# Silence warnings that would otherwise flood the console.
def warn(*args, **kwargs):
    """Stand-in for ``warnings.warn`` that accepts and discards any arguments."""
    return None
import warnings
# Route warnings.warn to the no-op ``warn`` defined above to keep the console clean.
warnings.warn = warn
# The patch only catches warnings raised through the Python attribute
# ``warnings.warn``; warnings emitted from C extensions (e.g. NumPy
# RuntimeWarnings on 0/0) bypass it, so also install a global "ignore" filter.
warnings.simplefilter('ignore')


def _minmax_normalize(features, reference):
    """Return a copy of *features* min-max scaled with *reference*'s column ranges.

    Each column becomes ``(x - min_ref) / (max_ref - min_ref + 1e-6)``; the small
    epsilon avoids division by zero for constant columns. Every column of
    *features* must also exist in *reference*.
    """
    scaled = features.copy()
    for col in scaled.columns:
        col_min = reference[col].min()
        col_max = reference[col].max()
        scaled[col] = (scaled[col] - col_min) / (col_max - col_min + 0.000001)
    return scaled


def _plot_roc(fig_num, fpr, tpr, label, title):
    """Draw one ROC curve plus the chance diagonal on figure *fig_num* and show it."""
    plt.figure(fig_num)
    plt.plot([0, 1], [0, 1], 'k--')
    plt.plot(fpr, tpr, label=label)
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.title(title)
    plt.legend(loc='best')
    plt.show()


#################################################
# READ radiomics features of the test set from a CSV file (one row per slice)
testing_fn = '../radiomics_data/all_features_dummy_fname.csv'
test_table = pd.read_csv(testing_fn)
pats_ids = np.asarray(test_table['pat_id'])
# Test patients in order of first appearance, with one boolean row mask each.
# Grouping by id (instead of assuming a patient's slices are consecutive rows)
# keeps the per-patient aggregation correct for any row order.
unique_pats = pd.unique(pats_ids)
num_pats = len(unique_pats)
patient_masks = [pats_ids == pid for pid in unique_pats]

for cross_val in range(0, 5):
    ### READ radiomics features of this fold's training and validation sets
    training_fn = '../radiomics_data/all_features_train_XVAL_' + str(cross_val) + '.csv'
    validation_fn = '../radiomics_data/all_features_valid_XVAL_' + str(cross_val) + '.csv'
    train_table = pd.read_csv(training_fn)
    valid_table = pd.read_csv(validation_fn)
    # Training + validation rows form the development set; ignore_index gives
    # it a unique RangeIndex instead of two overlapping ones.
    develop_table = pd.concat([train_table, valid_table], ignore_index=True)

    # Determine the important features and remove all other features from input data
    y_dev = develop_table["slice_label"].copy()
    # The first 23 columns are pyradiomics diagnostics and are excluded, as are
    # the last 2 columns. Features are min-max normalized on the development set.
    X_dev = _minmax_normalize(develop_table.iloc[:, 23:-2], develop_table)
    req_num_rad_feats = 5
    lasso = LassoCV(random_state=init_rs, cv=5, verbose=False)  # use LASSO for feature selection
    # threshold=-inf disables the importance cut-off, so selection is driven
    # solely by max_features (at most `req_num_rad_feats` features are kept).
    sfm = SelectFromModel(lasso, threshold=-np.inf, max_features=req_num_rad_feats).fit(X_dev, y_dev)
    selected_feats_names = list(X_dev.columns[sfm.get_support()])
    print("Most important features are:")
    print(*selected_feats_names, sep="\n")

    ## Prediction MODEL
    # random_state is set because the liblinear solver uses it to shuffle the data.
    clf = LogisticRegression(class_weight='balanced', penalty='l1', solver='liblinear',
                             random_state=init_rs, verbose=False)
    clf.fit(X_dev[selected_feats_names], y_dev)

    # Normalize test features with the development ranges, not the test ranges:
    # in practice (prospective testing) the full test cohort is not available.
    X_test = _minmax_normalize(test_table[selected_feats_names], develop_table)
    y_test = test_table['slice_label'].copy()

    y_true = np.asarray(y_test)
    y_pred = clf.predict_proba(X_test)[:, 1]  # probability of class clf.classes_[1]
    yy = y_pred > 0.5  # Decision at operating point 50%
    print('###### Per-Slice Cross-validation ' + str(cross_val) + '   ########################')
    print(metrics.classification_report(y_true, yy))
    auc = metrics.roc_auc_score(y_true, y_pred)
    print('Per-Slice Area-Under-Curve = ' + str(auc))
    print('############################################################')
    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_pred)
    _plot_roc(1, fpr, tpr, 'Per Slice (area = {:.3f} )'.format(auc),
              'ROC curve - Slice LGE Predictions')

    ######################################## Per-Patient Analysis ########################################
    # A patient is predicted LGE+ when at least `cutoff_num_sl` slices exceed
    # the probability threshold; the reference label is LGE+ when any slice is.
    # Metrics are tabulated for thresholds 0.00, 0.01, ..., 0.99 (index = threshold * 100).
    num_th = 100
    Err = np.zeros([num_th, 1])
    pFP = np.zeros([num_th, 1])  # p for patient
    pFN = np.zeros([num_th, 1])
    pTP = np.zeros([num_th, 1])
    pTN = np.zeros([num_th, 1])
    pSN = np.zeros([num_th, 1])
    pSP = np.zeros([num_th, 1])
    pREC = np.zeros([num_th, 1])
    pPRE = np.zeros([num_th, 1])
    pACC = np.zeros([num_th, 1])

    cutoff_num_sl = 2
    cutoff_th = 50  # index of the operating point whose metrics are printed (threshold 0.50)

    # The reference patient label does not depend on the threshold: compute it once.
    subj_real = np.zeros([num_pats, 1])
    for i, mask in enumerate(patient_masks):
        if np.sum(y_true[mask]) > 0:  # ground truth: at least 1 slice is LGE+
            subj_real[i] = 1

    for cnt in range(num_th):
        yy = y_pred > cnt / 100
        subj_pred = np.zeros([num_pats, 1])
        for i, mask in enumerate(patient_masks):
            if np.sum(yy[mask]) >= cutoff_num_sl:  # at least 2 slices predicted LGE+
                subj_pred[i] = 1

        diff = subj_pred - subj_real
        both = subj_pred + subj_real
        Err[cnt] = np.sum(np.abs(diff))
        pFN[cnt] = np.count_nonzero(diff == -1)
        pFP[cnt] = np.count_nonzero(diff == 1)
        pTN[cnt] = np.count_nonzero(both == 0)
        pTP[cnt] = np.count_nonzero(both == 2)

        pSN[cnt] = pTP[cnt] / (pTP[cnt] + pFN[cnt])
        pSP[cnt] = pTN[cnt] / (pTN[cnt] + pFP[cnt])
        # Recall = TP / (TP + FN), i.e. sensitivity (TN / (TN + FN) would be the NPV).
        pREC[cnt] = pTP[cnt] / (pTP[cnt] + pFN[cnt] + 0.00001)
        pPRE[cnt] = pTP[cnt] / (pTP[cnt] + pFP[cnt] + 0.00001)
        pACC[cnt] = (pTP[cnt] + pTN[cnt]) / (pTP[cnt] + pFP[cnt] + pTN[cnt] + pFN[cnt])

        if cnt == cutoff_th:  # operating point: 50% threshold
            print([x[0] for x in subj_pred])
            print('Patient Sens     : ', pSN[cnt])
            print('Patient Spec     : ', pSP[cnt])
            print('Patient Recall   : ', pREC[cnt])
            print('Patient Precision: ', pPRE[cnt])
            print('Patient Accuracy : ', pACC[cnt])

    # With increasing thresholds the FPR/TPR pairs run from high to low. Anchor
    # the curve at (1, 1) and (0, 0), as metrics.roc_curve does for slices, so
    # the area spans the full FPR range even if the sweep does not reach it.
    roc_fpr = np.concatenate(([1.0], 1 - pSP.ravel(), [0.0]))
    roc_tpr = np.concatenate(([1.0], pSN.ravel(), [0.0]))
    AUC = metrics.auc(roc_fpr, roc_tpr)
    print('Per-Patient AUC', AUC)
    _plot_roc(2, roc_fpr, roc_tpr, 'Per-Patient (area = {:.3f} )'.format(AUC),
              'ROC curve - Patient LGE Prediction')